/*
 * contrib/pg_trgm/trgm_op.c
 */
#include "postgres.h"
#include "knl/knl_variable.h"

#include <ctype.h>

#include "trgm.h"

#include "catalog/pg_type.h"
#include "tsearch/ts_locale.h"

/* Module magic block required of every dynamically loaded extension library. */
PG_MODULE_MAGIC;

/*
 * Similarity threshold used by similarity_op().  Changed through set_limit(),
 * which only accepts values between 0 and 1, and reported by show_limit().
 */
float4 trgm_limit = 0.3f;

/*
 * Version-1 calling convention declarations for the SQL-callable functions
 * defined in this file.  They are declared extern "C" so the symbols are not
 * name-mangled.
 */
PG_FUNCTION_INFO_V1(set_limit);
extern "C" Datum set_limit(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(show_limit);
extern "C" Datum show_limit(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(show_trgm);
extern "C" Datum show_trgm(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(similarity);
extern "C" Datum similarity(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(similarity_dist);
extern "C" Datum similarity_dist(PG_FUNCTION_ARGS);

PG_FUNCTION_INFO_V1(similarity_op);
extern "C" Datum similarity_op(PG_FUNCTION_ARGS);

/*
 * set_limit(float4) - SQL-callable: replace the similarity threshold.
 *
 * Argument 0 is the new threshold; it must lie in the closed range [0, 1].
 * Anything else, including NaN, raises an ERROR and leaves trgm_limit
 * untouched.  Returns the threshold that is now in effect.
 */
Datum set_limit(PG_FUNCTION_ARGS)
{
    float4 nlimit = PG_GETARG_FLOAT4(0);

    /*
     * Test for "inside the range" and negate, rather than testing for
     * "outside the range": NaN fails every ordered comparison, so a test of
     * the form (x < 0 || x > 1) would let NaN through and every later
     * "res >= trgm_limit" comparison would then be false.
     */
    if (!(nlimit >= 0 && nlimit <= 1.0))
        elog(ERROR, "wrong limit, should be between 0 and 1");
    trgm_limit = nlimit;
    PG_RETURN_FLOAT4(trgm_limit);
}

/*
 * show_limit() - SQL-callable: report the similarity threshold currently
 * stored in trgm_limit.
 */
Datum show_limit(PG_FUNCTION_ARGS)
{
    float4 current_limit = trgm_limit;

    PG_RETURN_FLOAT4(current_limit);
}

/*
 * qsort() comparator for trigram arrays: orders two trigrams by CMPTRGM.
 */
static int comp_trgm(const void* a, const void* b)
{
    int order = CMPTRGM(a, b);

    return order;
}

/*
 * Remove duplicate trigrams from an array, in place.
 *
 * a:   trigram array; equal trigrams must be adjacent (callers sort first)
 * len: number of trigrams in a
 *
 * Returns the number of distinct trigrams, which are left packed at the
 * front of a.  An empty array (len <= 0) yields 0.
 */
static int unique_array(trgm* a, int len)
{
    trgm* last = a; /* last distinct trigram kept so far */
    trgm* scan = NULL;

    /* nothing to keep; without this guard the result would wrongly be 1 */
    if (len <= 0)
        return 0;

    /* element 0 is always kept, so start comparing from element 1 */
    for (scan = a + 1; scan - a < len; scan++) {
        if (CMPTRGM(scan, last)) {
            last++;
            /* no need to copy a trigram onto itself */
            if (last != scan)
                CPTRGM(last, scan);
        }
    }

    return (int)(last + 1 - a);
}

/*
 * iswordchr(c): does the character that c points to belong to a word?
 *
 * With KEEPONLYALNUM defined only alphabetic characters and digits count;
 * otherwise every character that is not whitespace counts.
 */
#ifdef KEEPONLYALNUM
#define iswordchr(c) (t_isalpha(c) || t_isdigit(c))
#else
#define iswordchr(c) (!t_isspace(c))
#endif

/*
 * Locate the next word within the first lenstr bytes of str.
 *
 * Returns a pointer to the first byte of the word, or NULL when only
 * non-word characters remain.  On success *endword receives the position
 * just past the word and *charlen the number of characters in it (one per
 * pg_mblen step).  On failure neither output is written.
 */
static char* find_word(char* str, int lenstr, char** endword, int* charlen)
{
    char* wordstart = str;
    char* cursor = NULL;
    int nchars = 0;

    /* skip everything that is not a word character */
    for (; wordstart - str < lenstr; wordstart += pg_mblen(wordstart)) {
        if (iswordchr(wordstart))
            break;
    }

    /* ran off the end without meeting a word */
    if (wordstart - str >= lenstr)
        return NULL;

    /* walk to the end of the word, counting characters on the way */
    for (cursor = wordstart; cursor - str < lenstr && iswordchr(cursor); cursor += pg_mblen(cursor))
        nchars++;

    *endword = cursor;
    *charlen = nchars;

    return wordstart;
}

#ifdef USE_WIDE_UPPER_LOWER
/*
 * Store one trigram for the bytelen bytes starting at str into *tptr.
 *
 * A three-byte sequence is stored verbatim.  Any other length is reduced to
 * a CRC32 of the bytes and the trigram is taken from that CRC value.
 */
static void cnt_trigram(trgm* tptr, char* str, int bytelen)
{
    pg_crc32 hash;

    if (bytelen == 3) {
        CPTRGM(tptr, str);
        return;
    }

    INIT_CRC32(hash);
    COMP_CRC32(hash, str, bytelen);
    FIN_CRC32(hash);

    /*
     * Only a trigram's worth of the CRC is kept, copied from the start of
     * the CRC variable; that is hoped to be a good enough hash.
     */
    CPTRGM(tptr, &hash);
}
#endif

/*
 * Adds trigrams from words (already padded).
 *
 * tptr:    next free slot in the output trigram array
 * str:     the padded word
 * bytelen: length of str in bytes
 * charlen: length of str in characters
 *
 * Returns the slot following the last trigram written, or tptr unchanged
 * when the word has fewer than three characters.  No bounds check is made
 * on the output array; the caller must have sized it.
 */
static trgm* make_trigrams(trgm* tptr, char* str, int bytelen, int charlen)
{
    char* ptr = str;

    /* fewer than three characters cannot form a trigram */
    if (charlen < 3) {
        return tptr;
    }

#ifdef USE_WIDE_UPPER_LOWER
    if (pg_database_encoding_max_length() > 1) {
        /* byte lengths of the three characters in the current window */
        int lenfirst = pg_mblen(str), lenmiddle = pg_mblen(str + lenfirst),
            lenlast = pg_mblen(str + lenfirst + lenmiddle);

        /* emit one trigram per window position while the window fits in str */
        while ((ptr - str) + lenfirst + lenmiddle + lenlast <= bytelen) {
            cnt_trigram(tptr, ptr, lenfirst + lenmiddle + lenlast);

            ptr += lenfirst;
            tptr++;

            /*
             * Slide the window one character to the right.
             * NOTE(review): after the final trigram this pg_mblen call
             * inspects the byte at str + bytelen, i.e. just past the word;
             * confirm callers always leave a readable byte there.
             */
            lenfirst = lenmiddle;
            lenmiddle = lenlast;
            lenlast = pg_mblen(ptr + lenfirst + lenmiddle);
        }
    } else
#endif
    {
        /* single-byte encoding: every character is exactly one byte */
        Assert(bytelen == charlen);

        while (ptr - str < bytelen - 2 /* number of trigrams = strlen - 2 */) {
            CPTRGM(tptr, ptr);
            ptr++;
            tptr++;
        }
    }

    return tptr;
}

/*
 * Build the sorted, duplicate-free trigram set of a string.
 *
 * str:  source text, need not be null-terminated
 * slen: length of str in bytes
 *
 * Each word found by find_word() is (under IGNORECASE) lower-cased, padded
 * with LPADDING leading and RPADDING trailing blanks, and split into
 * trigrams by make_trigrams().
 *
 * Returns a palloc'd TRGM with flag ARRKEY; it holds no trigrams when the
 * input is empty or contains no word.  The caller is responsible for
 * freeing it.
 */
TRGM* generate_trgm(char* str, int slen)
{
    TRGM* trg = NULL;
    char* buf = NULL;
    trgm* tptr = NULL;
    int len, charlen, bytelen;
    char* bword = NULL;
    char* eword = NULL;

    trg = (TRGM*)palloc(TRGMHDRSIZE + sizeof(trgm) * (slen / 2 + 1) * 3);
    trg->flag = ARRKEY;
    SET_VARSIZE(trg, TRGMHDRSIZE);

    if (slen + LPADDING + RPADDING < 3 || slen == 0) {
        return trg;
    }

    tptr = GETARR(trg);

    /*
     * Scratch buffer for one padded word.  Under IGNORECASE the word copied
     * in is the lower-cased copy, and in a multibyte encoding a lower-case
     * character may occupy more bytes than its upper-case form.  Sizing the
     * buffer from slen alone could therefore be overrun by the memcpy below,
     * so allow the encoding's maximum character width for every input byte,
     * plus four bytes for the padding blanks.
     */
    buf = (char*)palloc(sizeof(char) * slen * pg_database_encoding_max_length() + 4);

    if (LPADDING > 0) {
        *buf = ' ';
        if (LPADDING > 1) {
            *(buf + 1) = ' ';
        }
    }

    eword = str;
    while ((bword = find_word(eword, slen - (eword - str), &eword, &charlen)) != NULL) {
#ifdef IGNORECASE
        /* bword now points at a separately allocated, lower-cased copy */
        bword = lowerstr_with_len(bword, eword - bword);
        bytelen = strlen(bword);
#else
        bytelen = eword - bword;
#endif

        memcpy(buf + LPADDING, bword, bytelen);

#ifdef IGNORECASE
        pfree(bword);
#endif
        /* trailing blanks; the leading ones were written once, before the loop */
        buf[LPADDING + bytelen] = ' ';
        buf[LPADDING + bytelen + 1] = ' ';

        /*
         * count trigrams
         */
        tptr = make_trigrams(tptr, buf, bytelen + LPADDING + RPADDING, charlen + LPADDING + RPADDING);
    }

    pfree(buf);

    /* no trigram was produced: hand back the empty set built above */
    if ((len = tptr - GETARR(trg)) == 0)
        return trg;

    /* sort so that duplicates become adjacent, then squeeze them out */
    qsort((void*)GETARR(trg), len, sizeof(trgm), comp_trgm);
    len = unique_array(GETARR(trg), len);

    SET_VARSIZE(trg, CALCGTSIZE(ARRKEY, len));

    return trg;
}

/*
 * Extract the next non-wildcard part of a search string, ie, a word bounded
 * by '_' or '%' meta-characters, non-word characters or string end.
 *
 * str: source string, of length lenstr bytes (need not be null-terminated)
 * buf: where to return the substring (must be long enough)
 * *bytelen: receives byte length of the found substring
 * *charlen: receives character length of the found substring
 *
 * Both lengths include any padding spaces this function adds to buf.
 *
 * Returns pointer to end+1 of the found substring in the source string.
 * Returns NULL if no word found (in which case buf, bytelen, charlen not set)
 *
 * If the found word is bounded by non-word characters or string boundaries
 * then this function will include corresponding padding spaces into buf.
 */
static const char* get_wildcard_part(const char* str, int lenstr, char* buf, int* bytelen, int* charlen)
{
    const char* beginword = str;
    const char* endword = NULL;
    char* s = buf; /* write position in buf */
    bool in_leading_wildcard_meta = false;
    bool in_trailing_wildcard_meta = false;
    bool in_escape = false;
    int clen; /* byte length of the character at endword */

    /*
     * Find the first word character, remembering whether preceding character
     * was wildcard meta-character.  Note that the in_escape state persists
     * from this loop to the next one, since we may exit at a word character
     * that is in_escape.
     */
    while (beginword - str < lenstr) {
        if (in_escape) {
            if (iswordchr(beginword))
                break;
            /* an escaped non-word character is skipped like any other */
            in_escape = false;
            in_leading_wildcard_meta = false;
        } else {
            if (ISESCAPECHAR(beginword)) {
                in_escape = true;
            } else if (ISWILDCARDCHAR(beginword)) {
                in_leading_wildcard_meta = true;
            } else if (iswordchr(beginword)) {
                break;
            } else {
                in_leading_wildcard_meta = false;
            }
        }
        beginword += pg_mblen(beginword);
    }

    /*
     * Handle string end.
     */
    if (beginword - str >= lenstr)
        return NULL;

    /*
     * Add left padding spaces if preceding character wasn't wildcard
     * meta-character.
     */
    *charlen = 0;
    if (!in_leading_wildcard_meta) {
        if (LPADDING > 0) {
            *s++ = ' ';
            (*charlen)++;
            if (LPADDING > 1) {
                *s++ = ' ';
                (*charlen)++;
            }
        }
    }

    /*
     * Copy data into buf until wildcard meta-character, non-word character or
     * string boundary.  Strip escapes during copy.
     */
    endword = beginword;
    while (endword - str < lenstr) {
        clen = pg_mblen(endword);
        if (in_escape) {
            if (iswordchr(endword)) {
                memcpy(s, endword, clen);
                (*charlen)++;
                s += clen;
            } else {
                /*
                 * Back up endword to the escape character when stopping at
                 * an escaped char, so that subsequent get_wildcard_part will
                 * restart from the escape character.  We assume here that
                 * escape chars are single-byte.
                 */
                endword--;
                break;
            }
            in_escape = false;
        } else {
            if (ISESCAPECHAR(endword)) {
                /* the escape character itself is not copied */
                in_escape = true;
            } else if (ISWILDCARDCHAR(endword)) {
                in_trailing_wildcard_meta = true;
                break;
            } else if (iswordchr(endword)) {
                memcpy(s, endword, clen);
                (*charlen)++;
                s += clen;
            } else {
                break;
            }
        }
        endword += clen;
    }

    /*
     * Add right padding spaces if next character isn't wildcard
     * meta-character.
     */
    if (!in_trailing_wildcard_meta) {
        if (RPADDING > 0) {
            *s++ = ' ';
            (*charlen)++;
            if (RPADDING > 1) {
                *s++ = ' ';
                (*charlen)++;
            }
        }
    }

    *bytelen = s - buf;
    return endword;
}

/*
 * Generates trigrams for wildcard search string.
 *
 * str:  the wildcard pattern, need not be null-terminated
 * slen: length of str in bytes
 *
 * Returns array of trigrams that must occur in any string that matches the
 * wildcard string.  For example, given pattern "a%bcd%" the trigrams
 * " a", "bcd" would be extracted.  The result is palloc'd, sorted and free
 * of duplicates; it holds no trigrams when nothing could be extracted.
 */
TRGM* generate_wildcard_trgm(const char* str, int slen)
{
    TRGM* result = (TRGM*)palloc(TRGMHDRSIZE + sizeof(trgm) * (slen / 2 + 1) * 3);
    trgm* next = NULL;
    char* part = NULL;
    char* folded = NULL;
    const char* pos = str;
    int ntrgm, charlen, bytelen;

    result->flag = ARRKEY;
    SET_VARSIZE(result, TRGMHDRSIZE);

    if (slen == 0 || slen + LPADDING + RPADDING < 3)
        return result;

    next = GETARR(result);
    part = (char*)palloc(sizeof(char) * (slen + 4));

    /* one pass per non-wildcard part delivered by get_wildcard_part */
    for (;;) {
        pos = get_wildcard_part(pos, slen - (pos - str), part, &bytelen, &charlen);
        if (pos == NULL)
            break;

#ifdef IGNORECASE
        folded = lowerstr_with_len(part, bytelen);
        bytelen = strlen(folded);
#else
        folded = part;
#endif

        /* turn this part into trigrams */
        next = make_trigrams(next, folded, bytelen, charlen);

#ifdef IGNORECASE
        pfree(folded);
#endif
    }

    pfree(part);

    ntrgm = next - GETARR(result);
    if (ntrgm == 0)
        return result;

    /* sort, then drop duplicates */
    qsort((void*)GETARR(result), ntrgm, sizeof(trgm), comp_trgm);
    ntrgm = unique_array(GETARR(result), ntrgm);

    SET_VARSIZE(result, CALCGTSIZE(ARRKEY, ntrgm));

    return result;
}

/*
 * Pack the three bytes of a trigram into an integer: the first byte ends up
 * in bits 16-23, the second in bits 8-15 and the third in bits 0-7.
 */
uint32 trgm2int(trgm* ptr)
{
    const unsigned char* bytes = (const unsigned char*)ptr;

    return ((uint32)bytes[0] << 16) | ((uint32)bytes[1] << 8) | (uint32)bytes[2];
}

/*
 * show_trgm(text) - SQL-callable: return the trigrams of the argument as a
 * text array.
 *
 * In a multibyte database encoding a trigram that ISPRINTABLETRGM rejects is
 * rendered as "0x" followed by six hex digits; every other trigram is
 * returned as its three raw bytes.
 */
Datum show_trgm(PG_FUNCTION_ARGS)
{
    text* in = PG_GETARG_TEXT_P(0);
    const int bufsize = 12;
    TRGM* trg = generate_trgm(VARDATA(in), VARSIZE(in) - VARHDRSZ);
    int nelems = ARRNELEM(trg);
    Datum* d = (Datum*)palloc(sizeof(Datum) * (1 + nelems));
    ArrayType* result = NULL;
    trgm* cur = GETARR(trg);
    int idx;

    for (idx = 0; idx < nelems; idx++) {
        text* item = (text*)palloc(VARHDRSZ + Max(bufsize, pg_database_encoding_max_length() * 3));
        bool as_hex = pg_database_encoding_max_length() > 1 && !ISPRINTABLETRGM(cur);

        if (as_hex) {
            int rc = snprintf_s(VARDATA(item), bufsize, bufsize - 1, "0x%06x", trgm2int(cur));
            securec_check_ss(rc, "", "");
            SET_VARSIZE(item, VARHDRSZ + strlen(VARDATA(item)));
        } else {
            SET_VARSIZE(item, VARHDRSZ + 3);
            CPTRGM(VARDATA(item), cur);
        }

        d[idx] = PointerGetDatum(item);
        cur++;
    }

    result = construct_array(d, nelems, TEXTOID, -1, false, 'i');

    /* the items are released once the array has been constructed */
    for (idx = 0; idx < nelems; idx++)
        pfree(DatumGetPointer(d[idx]));

    pfree(d);
    pfree(trg);
    PG_FREE_IF_COPY(in, 0);

    PG_RETURN_POINTER(result);
}

/*
 * Similarity of two trigram sets.
 *
 * Both arrays are walked in step, which relies on them being sorted, and the
 * trigrams present in both are counted.  With DIVUNION defined the result is
 * that count divided by the size of the union (len1 + len2 - count);
 * otherwise it is divided by the larger of the two set sizes.  Returns 0
 * when either set is empty.
 */
float4 cnt_sml(TRGM* trg1, TRGM* trg2)
{
    trgm* arr1 = GETARR(trg1);
    trgm* arr2 = GETARR(trg2);
    int len1 = ARRNELEM(trg1);
    int len2 = ARRNELEM(trg2);
    int pos1 = 0;
    int pos2 = 0;
    int count = 0;

    /* guard against 0/0 below when both sets are empty */
    if (len1 <= 0 || len2 <= 0)
        return (float4)0.0;

    while (pos1 < len1 && pos2 < len2) {
        trgm* lhs = arr1 + pos1;
        trgm* rhs = arr2 + pos2;
        int res = CMPTRGM(lhs, rhs);

        if (res == 0) {
            /* present in both sets */
            pos1++;
            pos2++;
            count++;
        } else if (res < 0) {
            pos1++;
        } else {
            pos2++;
        }
    }

#ifdef DIVUNION
    return ((float4)count) / ((float4)(len1 + len2 - count));
#else
    return ((float4)count) / ((float4)((len1 > len2) ? len1 : len2));
#endif
}

/*
 * Returns true when every trigram of trg1 also occurs in trg2.
 * Both trigram arrays must be sorted for the merge-style walk to work.
 */
bool trgm_contained_by(TRGM* trg1, TRGM* trg2)
{
    trgm* arr1 = GETARR(trg1);
    trgm* arr2 = GETARR(trg2);
    int len1 = ARRNELEM(trg1);
    int len2 = ARRNELEM(trg2);
    int pos1 = 0;
    int pos2 = 0;

    while (pos1 < len1 && pos2 < len2) {
        trgm* lhs = arr1 + pos1;
        trgm* rhs = arr2 + pos2;
        int res = CMPTRGM(lhs, rhs);

        /* trg2 has already moved past this trg1 trigram: it is missing */
        if (res < 0)
            return false;

        if (res == 0)
            pos1++;
        pos2++;
    }

    /* contained only if all of trg1 was matched before trg2 ran out */
    return pos1 >= len1;
}

/*
 * similarity(text, text) - SQL-callable: trigram similarity of the two
 * arguments, as computed by cnt_sml() on their trigram sets.
 */
Datum similarity(PG_FUNCTION_ARGS)
{
    text* in1 = PG_GETARG_TEXT_P(0);
    text* in2 = PG_GETARG_TEXT_P(1);
    TRGM* set1 = generate_trgm(VARDATA(in1), VARSIZE(in1) - VARHDRSZ);
    TRGM* set2 = generate_trgm(VARDATA(in2), VARSIZE(in2) - VARHDRSZ);
    float4 score = cnt_sml(set1, set2);

    /* release the trigram sets and any copies made of the arguments */
    pfree(set1);
    pfree(set2);
    PG_FREE_IF_COPY(in1, 0);
    PG_FREE_IF_COPY(in2, 1);

    PG_RETURN_FLOAT4(score);
}

/*
 * similarity_dist(text, text) - SQL-callable: distance between the two
 * arguments, defined as one minus their similarity().
 */
Datum similarity_dist(PG_FUNCTION_ARGS)
{
    Datum sml = DirectFunctionCall2(similarity, PG_GETARG_DATUM(0), PG_GETARG_DATUM(1));
    float4 res = DatumGetFloat4(sml);

    PG_RETURN_FLOAT4(1.0 - res);
}

/*
 * similarity_op(text, text) - SQL-callable: true when the similarity() of
 * the two arguments is at least the current trgm_limit.
 */
Datum similarity_op(PG_FUNCTION_ARGS)
{
    Datum sml = DirectFunctionCall2(similarity, PG_GETARG_DATUM(0), PG_GETARG_DATUM(1));
    float4 res = DatumGetFloat4(sml);

    PG_RETURN_BOOL(res >= trgm_limit);
}
